{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# Compiling a Transformer using `torch_tensorrt.dynamo.compile` and TensorRT\n\nThis interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a transformer-based model.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Imports and Model Definition\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch_tensorrt\nfrom transformers import BertModel"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Initialize model with float precision and sample inputs\nmodel = BertModel.from_pretrained(\"bert-base-uncased\").eval().to(\"cuda\")\ninputs = [\n    torch.randint(0, 2, (1, 14), dtype=torch.int32).to(\"cuda\"),\n    torch.randint(0, 2, (1, 14), dtype=torch.int32).to(\"cuda\"),\n]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Optional Input Arguments to `torch_tensorrt.dynamo.compile`\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Enabled precision for TensorRT optimization\nenabled_precisions = {torch.float}\n\n# Whether to print verbose logs\ndebug = True\n\n# Workspace size for TensorRT\nworkspace_size = 20 << 30\n\n# Minimum number of operators per TRT Engine block\n# (Lower value allows more graph segmentation)\nmin_block_size = 3\n\n# Operations to Run in Torch, regardless of converter support\ntorch_executed_ops = {}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Compilation with `torch_tensorrt.dynamo.compile`\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Build and compile the model with torch_tensorrt.dynamo.compile\noptimized_model = torch_tensorrt.dynamo.compile(\n    model,\n    inputs,\n    enabled_precisions=enabled_precisions,\n    debug=debug,\n    workspace_size=workspace_size,\n    min_block_size=min_block_size,\n    torch_executed_ops=torch_executed_ops,\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Equivalently, we could have run the above via the convenience frontend, like so:\n`torch_tensorrt.compile(model, ir=\"dynamo_compile\", inputs=inputs, ...)`\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Inference\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Does not cause recompilation (same batch size as input)\nnew_inputs = [\n    torch.randint(0, 2, (1, 14), dtype=torch.int32).to(\"cuda\"),\n    torch.randint(0, 2, (1, 14), dtype=torch.int32).to(\"cuda\"),\n]\nnew_outputs = optimized_model(*new_inputs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Does cause recompilation (new batch size)\nnew_inputs = [\n    torch.randint(0, 2, (4, 14), dtype=torch.int32).to(\"cuda\"),\n    torch.randint(0, 2, (4, 14), dtype=torch.int32).to(\"cuda\"),\n]\nnew_outputs = optimized_model(*new_inputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Cleanup\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Finally, we use Torch utilities to clean up the workspace\ntorch._dynamo.reset()\n\nwith torch.no_grad():\n    torch.cuda.empty_cache()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}